#!/usr/bin/env python3
'''
Author: Paschalis Agapitos
Project: Mestizajes

Provides the high-level embeddings API for the project by orchestrating
corpus loading, vector generation/loading, pooled embedding construction for
centuries and persons, and downstream PCA scree analysis. The module exposes
convenient entry points for running reusable embedding workflows across
different fields of human activity.
'''

from pathlib import Path
from typing import Dict, Optional, Tuple

import numpy as np
import pandas as pd

from .corpus import LinkCorpusBuilder, build_corpora_from_parquet
from .pooling import (
    LinkEmbedder,
    GroupEmbedder,
    build_century_embeddings,
    build_person_embeddings,
)
from .pca import generate_pca_scree_plot


def _load_or_compute_vectors(
    vectors_path: Path,
    vocab,
    label: str,
    batch_size: int,
    device: Optional[str],
    embedder: Optional["LinkEmbedder"] = None,
) -> Tuple[np.ndarray, Optional["LinkEmbedder"]]:
    """
    Load cached link vectors from ``vectors_path``, or encode ``vocab`` afresh.

    A cached file is accepted only when ``LinkEmbedder.load`` succeeds and the
    result has one row per vocabulary entry. Otherwise the vocabulary is
    encoded with ``normalize=True`` and the result is written back to
    ``vectors_path`` via ``LinkEmbedder.save``.

    Parameters
    ----------
    vectors_path : Path
        Location of the cached vector file.
    vocab
        Unique links, as returned (first element) by
        ``LinkCorpusBuilder.build_vocab``.
    label : str
        Name used in the progress message (e.g. "century", "person").
    batch_size : int
        Batch size for encoding.
    device : Optional[str]
        Torch device string passed to ``LinkEmbedder`` if one must be built.
    embedder : Optional[LinkEmbedder]
        Already-constructed embedder to reuse. A new one is created only when
        this is ``None`` and recomputation is required.

    Returns
    -------
    vectors : np.ndarray
        Link vectors aligned with ``vocab``.
    embedder : Optional[LinkEmbedder]
        The embedder that was used (or passed in), so callers can reuse it.
    """
    try:
        vectors = LinkEmbedder.load(vectors_path)
        # A row-count mismatch means the vocabulary changed since the cache was
        # written, so the cache is stale.
        # NOTE(review): a changed vocab of the *same* size is not detected here.
        if vectors.shape[0] != len(vocab):
            raise ValueError("Vector file does not match current vocab size.")
    except Exception as exc:
        # Best-effort cache: any load or validation failure triggers recomputation,
        # but the reason is reported so corrupt or stale files are not silently masked.
        print(f"Computing vectors for {label} links... ({type(exc).__name__}: {exc})")
        if embedder is None:
            embedder = LinkEmbedder(device=device)
        vectors = embedder.encode(vocab, batch_size=batch_size, normalize=True)
        LinkEmbedder.save(vectors_path, vectors)
    return vectors, embedder


def get_embeddings(
    parquet_path: Path,
    vecs_cent_path: Path,
    vecs_person_path: Path,
    batch_size: int = 128,
    device: Optional[str] = None,
) -> Tuple[Dict[int, np.ndarray], Dict[str, np.ndarray]]:
    """
    Get embeddings for centuries and persons from a parquet file.

    Link vectors for each grouping are loaded from their cache file when it
    matches the current vocabulary size. Otherwise they are recomputed and
    saved. The link vectors are then pooled per group with ``GroupEmbedder``.

    Parameters
    ----------
    parquet_path : Path
        Path to the input parquet file
    vecs_cent_path : Path
        Path to save/load century vectors
    vecs_person_path : Path
        Path to save/load person vectors
    batch_size : int
        Batch size for encoding
    device : Optional[str]
        Torch device string (e.g., "cuda:1" or "cpu")

    Returns
    -------
    century_embeddings : Dict[int, np.ndarray]
        Dictionary mapping century to embedding vector
    person_embeddings : Dict[str, np.ndarray]
        Dictionary mapping person name to embedding vector
    """
    # Build corpora
    builder = LinkCorpusBuilder(parquet_path)
    builder.load()
    links_by_century = builder.links_by_century()      # index=int
    links_by_person = builder.links_by_person()        # index=str

    # Shared across both groupings so the encoder is constructed at most once,
    # and only when a cache actually needs recomputing.
    embedder = None

    # Century vectors (load or compute), then pool per century.
    uniq_cent, _ = LinkCorpusBuilder.build_vocab(links_by_century)
    E_cent, embedder = _load_or_compute_vectors(
        vecs_cent_path, uniq_cent, "century", batch_size, device, embedder
    )
    century_embeddings: Dict[int, np.ndarray] = GroupEmbedder(uniq_cent, E_cent).build(
        links_by_century
    )

    # Person vectors (load or compute), then pool per person.
    uniq_person, _ = LinkCorpusBuilder.build_vocab(links_by_person)
    E_person, embedder = _load_or_compute_vectors(
        vecs_person_path, uniq_person, "person", batch_size, device, embedder
    )
    person_embeddings: Dict[str, np.ndarray] = GroupEmbedder(uniq_person, E_person).build(
        links_by_person
    )

    return century_embeddings, person_embeddings


def _field_slug(field_of_activity: str) -> str:
    """
    Turn a field-of-activity name into a filename-safe, lower-case slug.

    Spaces become ``_`` and ``&`` becomes ``and`` (as before). Path separators
    and characters reserved in Windows filenames (``<>:"/\\|?*``) are replaced
    with ``_``. Otherwise the slug would point into a sub-directory or be an
    invalid filename.
    """
    slug = field_of_activity.replace(" ", "_").replace("&", "and").lower()
    return slug.translate({ord(ch): "_" for ch in '<>:"/\\|?*'})


def get_embeddings_for_fha(
    parquet_path: Path,
    field_of_activity: str,
    base_vectors_path: Path = Path("D:/Users/Paschalis/phd/data/connecting_people2/"),
    batch_size: int = 128,
    device: Optional[str] = None,
) -> Tuple[Dict[int, np.ndarray], Dict[str, np.ndarray]]:
    """
    Get embeddings for a specific field of human activity.

    The cached vector files are named after a sanitised form of
    ``field_of_activity``:
    ``century_link_vectors_<slug>.npy`` and
    ``per_person_link_vectors_<slug>.npy``, both stored under
    ``base_vectors_path``.

    Parameters
    ----------
    parquet_path : Path
        Path to the parquet file containing filtered data
    field_of_activity : str
        Name of the field of human activity
    base_vectors_path : Path
        Base directory for vector files. A ``str`` is also accepted. The
        default is a machine-specific absolute Windows path, so pass this
        explicitly on other machines.
    batch_size : int
        Batch size for encoding
    device : Optional[str]
        Torch device string (e.g., "cuda:1" or "cpu")

    Returns
    -------
    century_embeddings : Dict[int, np.ndarray]
        Dictionary mapping century to embedding vector
    person_embeddings : Dict[str, np.ndarray]
        Dictionary mapping person name to embedding vector
    """
    field_clean = _field_slug(field_of_activity)
    base = Path(base_vectors_path)  # tolerate str input
    vecs_cent_path = base / f"century_link_vectors_{field_clean}.npy"
    vecs_person_path = base / f"per_person_link_vectors_{field_clean}.npy"

    return get_embeddings(
        parquet_path,
        vecs_cent_path,
        vecs_person_path,
        batch_size=batch_size,
        device=device,
    )


def process_embeddings_pipeline(
    parquet_path: Path,
    field_of_activity: str,
    vectors_centuries_path: Optional[Path] = None,
    vectors_persons_path: Optional[Path] = None,
    scree_plot_path: Optional[Path] = None,
    sublinear_tf: bool = False,
    show_pca_plot: bool = True,
    device: Optional[str] = None,
) -> Tuple[Dict[int, np.ndarray], Dict[str, np.ndarray], np.ndarray, np.ndarray]:
    """
    Main pipeline function to process embeddings for a specific field of human activity.

    Builds the century and person corpora from ``parquet_path``, computes
    pooled embeddings for both via ``build_century_embeddings`` and
    ``build_person_embeddings``, and runs ``generate_pca_scree_plot`` on the
    century embeddings.

    Parameters
    ----------
    parquet_path : Path
        Path to the parquet file containing filtered data
    field_of_activity : str
        Name of the field of human activity
    vectors_centuries_path : Optional[Path]
        Optional path to save/load century vectors
    vectors_persons_path : Optional[Path]
        Optional path to save/load person vectors
    scree_plot_path : Optional[Path]
        Optional path to save PCA scree plot
    sublinear_tf : bool
        Whether to use sublinear TF in TF-IDF weighting
    show_pca_plot : bool
        Whether to display the PCA plot
    device : Optional[str]
        Torch device string (e.g., "cuda:1" or "cpu")

    Returns
    -------
    century_embeddings : Dict[int, np.ndarray]
        Dictionary of century embeddings
    person_embeddings : Dict[str, np.ndarray]
        Dictionary of person embeddings
    pve : np.ndarray
        Per-component variance values from ``generate_pca_scree_plot``
    pve_cum : np.ndarray
        Cumulative per-component variance values from ``generate_pca_scree_plot``
    """
    # Build corpora (the builder object itself is not needed here).
    _, links_by_century, links_by_person = build_corpora_from_parquet(parquet_path)

    # Century embeddings. The raw link vectors and TF-IDF weights are also
    # returned by the builder, but this pipeline does not use them.
    century_embeddings, _, _ = build_century_embeddings(
        links_by_century,
        field_of_activity,
        vectors_path=vectors_centuries_path,
        sublinear_tf=sublinear_tf,
        device=device,
    )

    # Person embeddings (same discarded extras as above).
    person_embeddings, _, _ = build_person_embeddings(
        links_by_person,
        field_of_activity,
        vectors_path=vectors_persons_path,
        sublinear_tf=sublinear_tf,
        device=device,
    )

    # PCA scree analysis is run on the century embeddings only.
    pve, pve_cum = generate_pca_scree_plot(
        century_embeddings,
        field_of_activity,
        scree_plot_path=scree_plot_path,
        show=show_pca_plot,
    )

    return century_embeddings, person_embeddings, pve, pve_cum
